/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <cusolverDn.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "mmio.h"

/* avoid Windows warnings (for example: strcpy, fscanf, etc.) */
#if defined(_WIN32)
#define _CRT_SECURE_NO_WARNINGS
#endif

/* various __inline__ __device__  function to initialize a T_ELEM */
template <typename T_ELEM>
__inline__ T_ELEM cuGet(int);
template <>
__inline__ float cuGet<float>(int x) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(int x) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(int x) {
  return (make_cuComplex(float(x), 0.0f));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int x) {
  return (make_cuDoubleComplex(double(x), 0.0));
}

template <typename T_ELEM>
__inline__ T_ELEM cuGet(int, int);
template <>
__inline__ float cuGet<float>(int x, int y) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(int x, int y) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(int x, int y) {
  return make_cuComplex(float(x), float(y));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(int x, int y) {
  return (make_cuDoubleComplex(double(x), double(y)));
}

template <typename T_ELEM>
__inline__ T_ELEM cuGet(float);
template <>
__inline__ float cuGet<float>(float x) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(float x) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(float x) {
  return (make_cuComplex(float(x), 0.0f));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float x) {
  return (make_cuDoubleComplex(double(x), 0.0));
}

template <typename T_ELEM>
__inline__ T_ELEM cuGet(float, float);
template <>
__inline__ float cuGet<float>(float x, float y) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(float x, float y) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(float x, float y) {
  return (make_cuComplex(float(x), float(y)));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(float x, float y) {
  return (make_cuDoubleComplex(double(x), double(y)));
}

template <typename T_ELEM>
__inline__ T_ELEM cuGet(double);
template <>
__inline__ float cuGet<float>(double x) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(double x) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(double x) {
  return (make_cuComplex(float(x), 0.0f));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double x) {
  return (make_cuDoubleComplex(double(x), 0.0));
}

template <typename T_ELEM>
__inline__ T_ELEM cuGet(double, double);
template <>
__inline__ float cuGet<float>(double x, double y) {
  return float(x);
}

template <>
__inline__ double cuGet<double>(double x, double y) {
  return double(x);
}

template <>
__inline__ cuComplex cuGet<cuComplex>(double x, double y) {
  return (make_cuComplex(float(x), float(y)));
}

template <>
__inline__ cuDoubleComplex cuGet<cuDoubleComplex>(double x, double y) {
  return (make_cuDoubleComplex(double(x), double(y)));
}

static void compress_index(const int *Ind, int nnz, int m, int *Ptr, int base) {
  int i;

  /* initialize everything to zero */
  for (i = 0; i < m + 1; i++) {
    Ptr[i] = 0;
  }
  /* count elements in every row */
  Ptr[0] = base;
  for (i = 0; i < nnz; i++) {
    Ptr[Ind[i] + (1 - base)]++;
  }
  /* add all the values */
  for (i = 0; i < m; i++) {
    Ptr[i + 1] += Ptr[i];
  }
}

struct cooFormat {
  int i;
  int j;
  int p;  // permutation
};

int cmp_cooFormat_csr(struct cooFormat *s, struct cooFormat *t) {
  if (s->i < t->i) {
    return -1;
  } else if (s->i > t->i) {
    return 1;
  } else {
    return s->j - t->j;
  }
}

int cmp_cooFormat_csc(struct cooFormat *s, struct cooFormat *t) {
  if (s->j < t->j) {
    return -1;
  } else if (s->j > t->j) {
    return 1;
  } else {
    return s->i - t->i;
  }
}

typedef int (*FUNPTR)(const void *, const void *);
typedef int (*FUNPTR2)(struct cooFormat *s, struct cooFormat *t);

static FUNPTR2 fptr_array[2] = {
    cmp_cooFormat_csr,
    cmp_cooFormat_csc,
};

static int verify_pattern(int m, int nnz, int *csrRowPtr, int *csrColInd) {
  int i, col, start, end, base_index;
  int error_found = 0;

  if (nnz != (csrRowPtr[m] - csrRowPtr[0])) {
    fprintf(stderr,
            "Error (nnz check failed): (csrRowPtr[%d]=%d - csrRowPtr[%d]=%d) "
            "!= (nnz=%d)\n",
            0, csrRowPtr[0], m, csrRowPtr[m], nnz);
    error_found = 1;
  }

  base_index = csrRowPtr[0];
  if ((0 != base_index) && (1 != base_index)) {
    fprintf(stderr, "Error (base index check failed): base index = %d\n",
            base_index);
    error_found = 1;
  }

  for (i = 0; (!error_found) && (i < m); i++) {
    start = csrRowPtr[i] - base_index;
    end = csrRowPtr[i + 1] - base_index;
    if (start > end) {
      fprintf(
          stderr,
          "Error (corrupted row): csrRowPtr[%d] (=%d) > csrRowPtr[%d] (=%d)\n",
          i, start + base_index, i + 1, end + base_index);
      error_found = 1;
    }
    for (col = start; col < end; col++) {
      if (csrColInd[col] < base_index) {
        fprintf(
            stderr,
            "Error (column vs. base index check failed): csrColInd[%d] < %d\n",
            col, base_index);
        error_found = 1;
      }
      if ((col < (end - 1)) && (csrColInd[col] >= csrColInd[col + 1])) {
        fprintf(stderr,
                "Error (sorting of the column indecis check failed): "
                "(csrColInd[%d]=%d) >= (csrColInd[%d]=%d)\n",
                col, csrColInd[col], col + 1, csrColInd[col + 1]);
        error_found = 1;
      }
    }
  }
  return error_found;
}

template <typename T_ELEM>
int loadMMSparseMatrix(char *filename, char elem_type, bool csrFormat, int *m,
                       int *n, int *nnz, T_ELEM **aVal, int **aRowInd,
                       int **aColInd, int extendSymMatrix) {
  MM_typecode matcode;
  double *tempVal;
  int *tempRowInd, *tempColInd;
  double *tval;
  int *trow, *tcol;
  int *csrRowPtr, *cscColPtr;
  int i, j, error, base, count;
  struct cooFormat *work;

  /* read the matrix */
  error = mm_read_mtx_crd(filename, m, n, nnz, &trow, &tcol, &tval, &matcode);
  if (error) {
    fprintf(stderr, "!!!! can not open file: '%s'\n", filename);
    return 1;
  }

  /* start error checking */
  if (mm_is_complex(matcode) && ((elem_type != 'z') && (elem_type != 'c'))) {
    fprintf(stderr, "!!!! complex matrix requires type 'z' or 'c'\n");
    return 1;
  }

  if (mm_is_dense(matcode) || mm_is_array(matcode) ||
      mm_is_pattern(matcode) /*|| mm_is_integer(matcode)*/) {
    fprintf(
        stderr,
        "!!!! dense, array, pattern and integer matrices are not supported\n");
    return 1;
  }

  /* if necessary symmetrize the pattern (transform from triangular to full) */
  if ((extendSymMatrix) && (mm_is_symmetric(matcode) ||
                            mm_is_hermitian(matcode) || mm_is_skew(matcode))) {
    // count number of non-diagonal elements
    count = 0;
    for (i = 0; i < (*nnz); i++) {
      if (trow[i] != tcol[i]) {
        count++;
      }
    }
    // allocate space for the symmetrized matrix
    tempRowInd = (int *)malloc((*nnz + count) * sizeof(int));
    tempColInd = (int *)malloc((*nnz + count) * sizeof(int));
    if (mm_is_real(matcode) || mm_is_integer(matcode)) {
      tempVal = (double *)malloc((*nnz + count) * sizeof(double));
    } else {
      tempVal = (double *)malloc(2 * (*nnz + count) * sizeof(double));
    }
    // copy the elements regular and transposed locations
    for (j = 0, i = 0; i < (*nnz); i++) {
      tempRowInd[j] = trow[i];
      tempColInd[j] = tcol[i];
      if (mm_is_real(matcode) || mm_is_integer(matcode)) {
        tempVal[j] = tval[i];
      } else {
        tempVal[2 * j] = tval[2 * i];
        tempVal[2 * j + 1] = tval[2 * i + 1];
      }
      j++;
      if (trow[i] != tcol[i]) {
        tempRowInd[j] = tcol[i];
        tempColInd[j] = trow[i];
        if (mm_is_real(matcode) || mm_is_integer(matcode)) {
          if (mm_is_skew(matcode)) {
            tempVal[j] = -tval[i];
          } else {
            tempVal[j] = tval[i];
          }
        } else {
          if (mm_is_hermitian(matcode)) {
            tempVal[2 * j] = tval[2 * i];
            tempVal[2 * j + 1] = -tval[2 * i + 1];
          } else {
            tempVal[2 * j] = tval[2 * i];
            tempVal[2 * j + 1] = tval[2 * i + 1];
          }
        }
        j++;
      }
    }
    (*nnz) += count;
    // free temporary storage
    free(trow);
    free(tcol);
    free(tval);
  } else {
    tempRowInd = trow;
    tempColInd = tcol;
    tempVal = tval;
  }
  // life time of (trow, tcol, tval) is over.
  // please use COO format (tempRowInd, tempColInd, tempVal)

  // use qsort to sort COO format
  work = (struct cooFormat *)malloc(sizeof(struct cooFormat) * (*nnz));
  if (NULL == work) {
    fprintf(stderr, "!!!! allocation error, malloc failed\n");
    return 1;
  }
  for (i = 0; i < (*nnz); i++) {
    work[i].i = tempRowInd[i];
    work[i].j = tempColInd[i];
    work[i].p = i;  // permutation is identity
  }

  if (csrFormat) {
    /* create row-major ordering of indices (sorted by row and within each row
     * by column) */
    qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[0]);
  } else {
    /* create column-major ordering of indices (sorted by column and within each
     * column by row) */
    qsort(work, *nnz, sizeof(struct cooFormat), (FUNPTR)fptr_array[1]);
  }

  // (tempRowInd, tempColInd) is sorted either by row-major or by col-major
  for (i = 0; i < (*nnz); i++) {
    tempRowInd[i] = work[i].i;
    tempColInd[i] = work[i].j;
  }

  // setup base
  // check if there is any row/col 0, if so base-0
  // check if there is any row/col equal to matrix dimension m/n, if so base-1
  int base0 = 0;
  int base1 = 0;
  for (i = 0; i < (*nnz); i++) {
    const int row = tempRowInd[i];
    const int col = tempColInd[i];
    if ((0 == row) || (0 == col)) {
      base0 = 1;
    }
    if ((*m == row) || (*n == col)) {
      base1 = 1;
    }
  }
  if (base0 && base1) {
    printf("Error: input matrix is base-0 and base-1 \n");
    return 1;
  }

  base = 0;
  if (base1) {
    base = 1;
  }

  /* compress the appropriate indices */
  if (csrFormat) {
    /* CSR format (assuming row-major format) */
    csrRowPtr = (int *)malloc(((*m) + 1) * sizeof(csrRowPtr[0]));
    if (!csrRowPtr) return 1;
    compress_index(tempRowInd, *nnz, *m, csrRowPtr, base);

    *aRowInd = csrRowPtr;
    *aColInd = (int *)malloc((*nnz) * sizeof(int));
  } else {
    /* CSC format (assuming column-major format) */
    cscColPtr = (int *)malloc(((*n) + 1) * sizeof(cscColPtr[0]));
    if (!cscColPtr) return 1;
    compress_index(tempColInd, *nnz, *n, cscColPtr, base);

    *aColInd = cscColPtr;
    *aRowInd = (int *)malloc((*nnz) * sizeof(int));
  }

  /* transfrom the matrix values of type double into one of the cusparse library
   * types */
  *aVal = (T_ELEM *)malloc((*nnz) * sizeof(T_ELEM));

  for (i = 0; i < (*nnz); i++) {
    if (csrFormat) {
      (*aColInd)[i] = tempColInd[i];
    } else {
      (*aRowInd)[i] = tempRowInd[i];
    }
    if (mm_is_real(matcode) || mm_is_integer(matcode)) {
      (*aVal)[i] = cuGet<T_ELEM>(tempVal[work[i].p]);
    } else {
      (*aVal)[i] =
          cuGet<T_ELEM>(tempVal[2 * work[i].p], tempVal[2 * work[i].p + 1]);
    }
  }

  /* check for corruption */
  int error_found;
  if (csrFormat) {
    error_found = verify_pattern(*m, *nnz, *aRowInd, *aColInd);
  } else {
    error_found = verify_pattern(*n, *nnz, *aColInd, *aRowInd);
  }
  if (error_found) {
    fprintf(stderr, "!!!! verify_pattern failed\n");
    return 1;
  }

  /* cleanup and exit */
  free(work);
  free(tempVal);
  free(tempColInd);
  free(tempRowInd);

  return 0;
}

/* specific instantiation */
template int loadMMSparseMatrix<float>(char *filename, char elem_type,
                                       bool csrFormat, int *m, int *n, int *nnz,
                                       float **aVal, int **aRowInd,
                                       int **aColInd, int extendSymMatrix);

template int loadMMSparseMatrix<double>(char *filename, char elem_type,
                                        bool csrFormat, int *m, int *n,
                                        int *nnz, double **aVal, int **aRowInd,
                                        int **aColInd, int extendSymMatrix);

template int loadMMSparseMatrix<cuComplex>(char *filename, char elem_type,
                                           bool csrFormat, int *m, int *n,
                                           int *nnz, cuComplex **aVal,
                                           int **aRowInd, int **aColInd,
                                           int extendSymMatrix);

template int loadMMSparseMatrix<cuDoubleComplex>(
    char *filename, char elem_type, bool csrFormat, int *m, int *n, int *nnz,
    cuDoubleComplex **aVal, int **aRowInd, int **aColInd, int extendSymMatrix);
